{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import tiktoken\n",
    "import datasets\n",
    "import langdetect\n",
    "from semantic_text_splitter import TextSplitter\n",
    "from string import Template\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# * load dataset from jsonlines file\n",
    "dataset = datasets.load_dataset(\"json\", data_files=\"raw_data/together-long/arxiv.json\", split=\"train\")\n",
    "\n",
    "dataset"
   ]
  },
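  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# * quick sanity check: confirm the expected columns and preview the first record's text\n",
    "print(dataset.column_names)\n",
    "print(dataset[0][\"text\"][:300])"
   ]
  },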
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# * set index for each sample\n",
    "dataset = dataset.map(lambda x, index: {\"index\": index}, with_indices=True, num_proc=32)\n",
    "\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# * filter data by length\n",
    "enc = tiktoken.encoding_for_model(\"gpt-4\")\n",
    "\n",
    "def filter_length(examples):\n",
    "    res = []\n",
    "    for text in examples[\"text\"]:\n",
    "        try:\n",
    "            token_len = len(enc.encode(text))\n",
    "        except:\n",
    "            res.append(False)\n",
    "            continue\n",
    "        if token_len < 32_000:\n",
    "            res.append(False)\n",
    "        elif token_len > 80_000:\n",
    "            res.append(False)\n",
    "        else:\n",
    "            res.append(True)\n",
    "\n",
    "    return res\n",
    "\n",
    "\n",
    "dataset = dataset.filter(filter_length, batched=True, num_proc=32)\n",
    "\n",
    "\n",
    "dataset"
   ]
  },
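  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# * sanity check: token lengths of a few retained samples should fall in [32k, 80k];\n",
    "# * reuses the `enc` tokenizer from the cell above\n",
    "for sample in dataset.select(range(3)):\n",
    "    print(sample[\"index\"], len(enc.encode(sample[\"text\"])))"
   ]
  },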
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# * filter non-English data\n",
    "dataset = dataset.filter(lambda x: langdetect.detect(x[\"text\"]) == \"en\", num_proc=32)\n",
    "\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# * random sample\n",
    "dataset = dataset.train_test_split(test_size=1_000, seed=2024)[\"test\"]\n",
    "\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# * save data as the backup\n",
    "dataset.to_json(\"backup_data/one_detail.paper.long.jsonl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = datasets.load_dataset(\"json\", data_files=\"backup_data/one_detail.paper.long.jsonl\", split=\"train\")\n",
    "\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# * put abstract at the beginning\n",
    "def process_abstract(example):\n",
    "    text = example[\"text\"]\n",
    "    abstract_idx = text.rfind(\"Abstract: \")\n",
    "    abstract = text[abstract_idx:]\n",
    "    text = text[:abstract_idx]\n",
    "\n",
    "    return {\"text\": f\"{abstract}\\n\\n{text}\"}\n",
    "\n",
    "dataset = dataset.map(process_abstract, num_proc=32)\n",
    "\n",
    "dataset"
   ]
  },
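  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# * sanity check: the text should now open with the abstract\n",
    "print(dataset[0][\"text\"][:200])"
   ]
  },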
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# * split text to several chunk\n",
    "splitter = TextSplitter.from_tiktoken_model(\"gpt-4\", trim_chunks=False)\n",
    "\n",
    "def split_text(examples, indices):\n",
    "    result = {\n",
    "        \"text\": [],\n",
    "        \"index\": [],\n",
    "        \"section_index\": [],\n",
    "    }\n",
    "\n",
    "    for i in range(len(examples[\"text\"])):\n",
    "        text = examples[\"text\"][i]\n",
    "        chunks = splitter.chunks(text=text, chunk_capacity=4096)\n",
    "\n",
    "        result[\"text\"].extend(chunks)\n",
    "        result[\"index\"].extend([indices[i] for _ in chunks])\n",
    "        result[\"section_index\"].extend([i for i in range(len(chunks))])\n",
    "\n",
    "    return result\n",
    "\n",
    "chunked_dataset = dataset.map(split_text, with_indices=True, batched=True, num_proc=32, remove_columns=dataset.column_names)\n",
    "\n",
    "chunked_dataset"
   ]
  },
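  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# * sanity check: how many ~4096-token chunks each paper produced\n",
    "from collections import Counter\n",
    "\n",
    "chunk_counts = Counter(chunked_dataset[\"index\"])\n",
    "print(\"papers:\", len(chunk_counts), \"| total chunks:\", len(chunked_dataset))\n",
    "print(\"chunks per paper: min\", min(chunk_counts.values()), \"max\", max(chunk_counts.values()))"
   ]
  },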
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"\"\"Context information is below.\n",
    "---------------------\n",
    "${context}\n",
    "---------------------\n",
    "Given the context information and not prior knowledge.\n",
    "Generate content based on the below query.\n",
    "You are a Teacher/Professor. Your task is to setup 4 questions for an upcoming quiz/examination. The questions should be diverse in nature across the document. Restrict the questions to the context information provided.\n",
    "You must return the result in JSON: [{'question': <question>, 'answer': <answer>}, ..., {'question': <question>, 'answer': <answer>}]\"\"\"\n",
    "\n",
    "# * organize the data format\n",
    "jobs = []\n",
    "\n",
    "for data in tqdm(chunked_dataset):\n",
    "    prompt = Template(template).substitute(context=data[\"text\"])\n",
    "    jobs.append({\n",
    "        \"model\": \"gpt-35-turbo\", \n",
    "        \"temperature\": 0,\n",
    "        \"top_p\": 1.0,\n",
    "        \"max_tokens\": 4096,\n",
    "        \"messages\": [\n",
    "            {\"role\": \"user\", \"content\": prompt},\n",
    "        ],\n",
    "        \"user\": f\"{data['index']}-{data['section_index']}\",\n",
    "    })\n",
    "\n",
    "# * save, and then use Openai API script to generate data\n",
    "with open(\"data/one_detail.paper.long.chunk.jsonl\", \"w\") as f:\n",
    "    for job in jobs:\n",
    "        json_string = json.dumps(job)\n",
    "        f.write(json_string + \"\\n\")\n"
   ]
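  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# * sanity check: re-read the request file written above and make sure one entry round-trips;\n",
    "# * fields mirror the jobs built in the previous cell\n",
    "with open(\"data/one_detail.paper.long.chunk.jsonl\") as f:\n",
    "    first = json.loads(f.readline())\n",
    "\n",
    "print(first[\"user\"], first[\"model\"], len(first[\"messages\"][0][\"content\"]))"
   ]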
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
